##########################################
# utils_direction.py
##########################################
import torch
import networkx as nx
import dgl

##########################################
# utils_direction.py
##########################################
import torch
import networkx as nx
import dgl
import torch.nn.functional as F

def attach_direction_mask(g):
    """
    Attach a PageRank-based direction mask to every edge of a DGLGraph.

    PageRank is computed on the unweighted NetworkX view of `g`. For each
    edge (u, v) the mask is 1.0 when PageRank(u) > PageRank(v) (u is treated
    as upstream) and 0.0 otherwise, including exact ties.

    Parameters:
      g: DGLGraph. Modified in place.

    Returns:
      g: the same graph, with g.edata['mask'] set to a float32 tensor of
         shape (num_edges, 1) on g.device.
    """
    pr = nx.pagerank(g.to_networkx())

    # Keep scores in float64: PageRank returns Python floats, and comparing
    # them in float32 could round nearly-equal scores into ties.
    scores = torch.tensor(
        [pr.get(n, 0.0) for n in range(g.num_nodes())], dtype=torch.float64
    )
    src, dst = g.edges()
    upstream = scores[src.cpu().long()] > scores[dst.cpu().long()]
    mask = upstream.float().unsqueeze(1)

    # Place the mask on the graph's own device (e.g. cuda:1) rather than the
    # default CUDA device that a bare `.cuda()` would pick.
    g.edata['mask'] = mask.to(g.device)
    return g

def update_direction_mask(g, x_star):
    """
    Recompute the edge mask of `g` from the current predicted flows `x_star`.

    A weighted directed multigraph with the same nodes and edges as `g` is
    built, with each edge weighted by |x_star| for that edge. Non-finite
    predictions (NaN/inf) fall back to weight 1.0. Weighted PageRank is
    computed on it. For each edge (u, v) the mask is 1.0 when
    PageRank(u) > PageRank(v) and 0.0 otherwise, including exact ties.

    Parameters:
      g: DGLGraph. Modified in place.
      x_star: torch.Tensor holding one flow prediction per edge, ordered like
        g.edges() (e.g. shape (num_edges, 1) or (num_edges,)).

    Returns:
      g: the same graph, with g.edata['mask'] set to a float32 tensor of
         shape (num_edges, 1) on g.device.

    Raises:
      ValueError: if x_star does not contain exactly one value per edge.
    """
    src, dst = g.edges()
    src_list = src.tolist()
    dst_list = dst.tolist()

    # reshape(-1) rather than squeeze(): squeeze() collapses a single-edge
    # input to a 0-d tensor that cannot be indexed per edge.
    flows = x_star.detach().reshape(-1)
    if flows.numel() != len(src_list):
        raise ValueError(
            f"x_star has {flows.numel()} values but the graph has "
            f"{len(src_list)} edges"
        )

    # Edge weight = |flow|; NaN/inf would poison PageRank, so use 1.0 instead.
    weights = flows.abs()
    weights = torch.where(
        torch.isfinite(weights), weights, torch.ones_like(weights)
    )

    # Build the weighted graph explicitly, one NetworkX edge per DGL edge.
    # Annotating g.to_networkx() does not work: it yields a MultiDiGraph whose
    # nx_g[u][v] is a read-only per-key view, and (u, v)-keyed weights would
    # also merge parallel edges.
    nx_g = nx.MultiDiGraph()
    nx_g.add_nodes_from(range(g.num_nodes()))
    nx_g.add_edges_from(
        (u, v, {'weight': w})
        for u, v, w in zip(src_list, dst_list, weights.cpu().tolist())
    )

    pr = nx.pagerank(nx_g, weight='weight')

    # Compare in float64 so nearly-equal PageRank scores are not rounded
    # into ties.
    scores = torch.tensor(
        [pr.get(n, 0.0) for n in range(g.num_nodes())], dtype=torch.float64
    )
    upstream = scores[src.cpu().long()] > scores[dst.cpu().long()]
    new_mask = upstream.float().unsqueeze(1)

    # Place the mask on the graph's own device rather than the default CUDA
    # device that a bare `.cuda()` would pick.
    g.edata['mask'] = new_mask.to(g.device)
    return g
